import numpy as np
import netCDF4
import matplotlib.pyplot as plt
import matplotlib as mpl
import netCDF4
import matplotlib.pyplot as plt
import os
from datetime import datetime
from warnings import filterwarnings
# Silence the DeprecationWarning whose text starts with "`np.bool` is a deprecated alias".
# This script never references np.bool itself, so the warning presumably comes from a
# dependency (netCDF4 / matplotlib) — TODO confirm whether it is still needed with the
# installed numpy version.
filterwarnings(action='ignore', category=DeprecationWarning, message='`np.bool` is a deprecated alias')

### nearest-neighbour lookup ###
def geo_idx(dd, dd_array):
    """Return the index of the element of ``dd_array`` closest to ``dd``.

    Parameters
    ----------
    dd : scalar
        Value whose nearest neighbour is wanted.
    dd_array : numpy.ndarray
        Array searched for the nearest neighbour. For a multi-dimensional
        array the result indexes the flattened array.

    Returns
    -------
    numpy integer
        Position of the smallest absolute difference; ties resolve to the
        first (lowest) index.
    """
    distance = np.abs(dd_array - dd)
    return np.argmin(distance)

### import ZMIN detection limit of KAZR at each model level ###
# Tab-delimited text table; presumably one minimum detectable reflectivity per model
# level — TODO confirm units and layout.
# NOTE(review): ZMIN is not referenced anywhere else in this script as shown, so this
# load currently only makes the run fail early if the file is missing. Either apply the
# detection limit to the model reflectivities or drop the load.
ZMIN = np.loadtxt("data_in/z_min_model_levels.txt", delimiter = "\t", skiprows = 0)

### data files depending on day ###
# `day` selects the case. `start_hour` is the offset (hours) added to the hour-of-day
# window when indexing the model time axis in `file_extract`:
#   = 0  for 12 March and 28 March (simulations start 12 March 12Z and 28 March 00Z)
#   = 12 for 13 March
#   = 24 for 29 March
# NOTE(review): if the 12 March simulation really starts at 12Z, hour h of 12 March is
# simulation hour h-12, which would suggest start_hour = -12 for that case — confirm
# against the model 'Time' variable before relying on 12 March plots.
day = 28

# day -> (KAZR ARSCL file, interpolated-sonde file, extracted crsim model file, start_hour)
CASES = {
    12: ('data_in/anxarsclkazr1kolliasM1.c1.20200312.000000.nc',
         'data_in/anxinterpolatedsondeM1.c1.20200312.000030.nc',
         'data_in/extracted_crsim_data_031220.nc',
         0),
    13: ('data_in/anxarsclkazr1kolliasM1.c1.20200313.000000.nc',
         'data_in/anxinterpolatedsondeM1.c1.20200313.000030.nc',
         'data_in/extracted_crsim_data_031220.nc',  # only one file per simulation
         12),
    28: ('data_in/anxarsclkazr1kolliasM1.c1.20200328.000000.nc',
         'data_in/anxinterpolatedsondeM1.c1.20200328.000030.nc',
         'data_in/extracted_crsim_data_032820.nc',
         0),
    29: ('data_in/anxarsclkazr1kolliasM1.c1.20200329.000000.nc',
         'data_in/anxinterpolatedsondeM1.c1.20200329.000030.nc',
         'data_in/extracted_crsim_data_032820.nc',  # only one file per simulation
         24),
}

# Fail here with a clear message instead of a NameError on the first file access.
try:
    file_kazr, file_interp, file_extract, start_hour = CASES[day]
except KeyError:
    raise ValueError(
        "unsupported day {!r}; expected one of {}".format(day, sorted(CASES))
    ) from None

### Transect time window (hour of day) and its two-digit labels ###
start_time, end_time = 12, 18  # hour of day at which transect starts and ends
# Hours below 10 get a leading "0"; the labels are used in the output filenames.
start_time_str, end_time_str = (("0" if hour < 10 else "") + str(hour)
                                for hour in (start_time, end_time))

dates_str = file_interp[-12:-10]  # two-digit day of month from the sonde file's YYYYMMDD stamp

### import the model data ###
# Model fields are sliced on their leading (time) axis at 600 indices per hour
# (presumably a 6-s output interval — TODO confirm against 'Time'); start_hour maps
# the hour-of-day window onto the simulation's time axis.
model_window = slice((start_hour + start_time) * 600, (start_hour + end_time) * 600)
# The context manager closes the file even if a variable is missing.
with netCDF4.Dataset(file_extract) as f:
    # Reflectivity of the four simulations, in the order of figure panels (a)-(d).
    Z1, Z2, Z3, Z4 = (np.array(f.variables[name][:])[model_window]
                      for name in ('Z_no_edmf', 'Z_mynn_l3', 'Z', 'Z_ysu'))
    # Matching T_* fields (the temperature arrays) used for the contour overlays.
    T1, T2, T3, T4 = (np.array(f.variables[name][:])[model_window]
                      for name in ('T_no_edmf', 'T_mynn_l3', 'T', 'T_ysu'))
    h = np.array(f.variables['Height'][:])
    time = np.array(f.variables['Time'][:])[model_window]

### import the KAZR observations ###
# 900 indices per hour (presumably a 4-s sampling interval — TODO confirm); the +1
# keeps the sample at end_time itself.
kazr_window = slice(start_time * 900, end_time * 900 + 1)
with netCDF4.Dataset(file_kazr) as f:
    Z = np.array(f.variables['reflectivity_best_estimate'][:])[kazr_window]
    hkazr = np.array(f.variables['height'][:])
    timekazr = np.array(f.variables['time'][:])[kazr_window]

### remove invalid data, the KAZR did not observe any values outside this range ###
invalidkazr = np.logical_or(Z > 35, Z < -50)

### Calculate average KAZR reflectivity in time windows ###
# Average in linear units, then convert back to dBZ. Out-of-range samples count as
# zero linear reflectivity; NaNs are skipped by nanmean. All-zero windows give -inf dBZ.
Zlin = 10**(Z/10)
Zlin[invalidkazr] = 0
ts_ave = 300  # averaging window in seconds
# Half-window in time steps: ts_ave / (2 * 4 s), assuming 4-s KAZR sampling — TODO confirm.
ave = int(ts_ave/8)
mins = "_05min"
time_ave = timekazr
# Centred running mean over Zlin[i-ave : i+ave+1], truncated at both ends of the record.
# One clamped slice per sample yields exactly one mean per time step, so Z_ave stays
# aligned with time_ave even when the record is shorter than 2*ave steps. Because
# time_ave *is* timekazr, sample i is its own nearest time (for strictly increasing
# times), so no per-sample nearest-neighbour search is needed.
Z_ave = np.asarray([
    10*np.log10(np.nanmean(Zlin[max(i - ave, 0):i + ave + 1], axis=0))
    for i in range(len(time_ave))
])

### set invalid values as nans ###
Z[invalidkazr] = np.nan

### Import interpolated sonde data ###
# 60 indices per hour (presumably 1-min resolution — TODO confirm); +1 keeps end_time.
sonde_window = slice(start_time * 60, end_time * 60 + 1)
with netCDF4.Dataset(file_interp) as f:
    theta = np.array(f.variables['potential_temp'][:])[sonde_window]
    # Scaled by 1000, presumably km -> m to match the model/KAZR height axes (plotted to 6000).
    hinterp = np.array(f.variables['height'][:])*1000
    timeinterp = np.array(f.variables['time'][:])[sonde_window]
# Blank out implausible potential temperatures (e.g. fill values) before contouring.
invalidinterp = np.logical_or(theta > 4000, theta < 0)
theta[invalidinterp] = np.nan

### time stamps ###
# Human-readable UTC label for every model time; a subset becomes the bottom-panel
# x tick labels. datetime.utcfromtimestamp is deprecated since Python 3.12; an aware
# UTC datetime formats identically with these strftime directives.
# NOTE(review): the +6 s offset is not explained here — confirm its purpose.
from datetime import timezone

time_surface = [
    datetime.fromtimestamp(t + 6, tz=timezone.utc).strftime('%H:%M UTC\n%d %b %Y')
    for t in time
]
base_time_start = time[0]


### Figure plotting ###

### Figure parameters ###
# Model time relative to its first sample, so the model panels start at x = 0.
time = np.asarray(time) - base_time_start
ylim = 6000  # top of the plotted height range
yticks = range(0, ylim + 100, 1000)  # 0, 1000, ..., 6000
yticklabels = range(0, 7, 1)  # the same ticks labelled in km
# Nine evenly spaced x ticks; labels are the model time stamps at fractions
# 0, 1/8, ..., 1 of the record (index truncated toward zero).
xticks = np.linspace(0, max(time), 9)
last_stamp_idx = len(time_surface) - 1
xticklabels = [time_surface[int(last_stamp_idx * 0.125 * k)] for k in range(0, 9)]
y = h

### Colour scale set up for plotting ###
# Discrete reflectivity scale: edges every 1 dBZ from vmin to vmax.
vmin, vmax = -50, 25
num = vmax - vmin + 1  # number of bin edges
bounds = np.linspace(vmin, vmax, num=num)
# Rebuild nipy_spectral as a LinearSegmentedColormap sampled at all of its colours.
base_cmap = plt.get_cmap('nipy_spectral')
colorlist = [base_cmap(k) for k in range(base_cmap.N)]
cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', colorlist, base_cmap.N)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

fig, ax = plt.subplots(6, 1, figsize=(12, 10/6*6))

# Dashed contour levels shared by every panel (every 2 units from -100 to 298).
contour_levels = np.round(np.arange(-100, 300, 2))


def _plot_panel(axis, x, height, refl, cx, cy, cfield, xlim, x_ticks, title, bottom=False):
    """Draw one reflectivity transect with a dashed magenta contour overlay.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    x, height : coordinates of the reflectivity mesh.
    refl : reflectivity laid out as (time, height); transposed for pcolormesh.
    cx, cy : coordinates of the contour field.
    cfield : contoured field laid out as (time, height).
    xlim : (left, right) x-axis limits.
    x_ticks : x tick positions.
    title : panel title, drawn at the left.
    bottom : if True, show the time-stamp tick labels (bottom panel only).

    Returns the pcolormesh artist, used for the shared colour bar.
    """
    mesh = axis.pcolormesh(x, height, np.transpose(refl), cmap=cmap, norm=norm)
    axis.set_ylim(0, ylim)
    axis.set_xlim(*xlim)
    axis.set_ylabel("MSL [km]")
    axis.set_yticks(yticks)
    axis.set_xticks(x_ticks)
    axis.set_yticklabels(yticklabels)
    axis.get_xaxis().set_ticklabels(xticklabels if bottom else [])
    axis.tick_params(labelbottom=bottom, labelleft=True, labelright=True,
                     bottom=True, top=True, left=True, right=True)
    axis.tick_params('both', length=6, width=1, which='major')
    axis.set_title(title, loc="left", pad=3.25)
    cntr = axis.contour(cx, cy, np.transpose(cfield), contour_levels, colors='magenta',
                        linewidths=0.8, linestyles="dashed", zorder=6)
    axis.clabel(cntr, inline=1, inline_spacing=8, fontsize=7, fmt='%i')
    return mesh


# Model panels use seconds since the first model sample; observation panels use the
# KAZR time axis. All panels share the same nine-tick layout.
model_xlim = (0, max(time))
xtickskazr = np.linspace(min(timekazr), max(timekazr), 9)
kazr_xlim = (min(timekazr), max(timekazr))

panels = [
    # (x, height, reflectivity, contour x, contour y, contour field, xlim, xticks, title)
    (time, y, Z1, time, y, T1, model_xlim, xticks, " (a)  Level 2.5"),
    (time, y, Z2, time, y, T2, model_xlim, xticks, " (b)  Level 3"),
    (time, y, Z3, time, y, T3, model_xlim, xticks, " (c)  Level 2.5 EDMF"),
    (time, y, Z4, time, y, T4, model_xlim, xticks, " (d)  YSU"),
    (time_ave, hkazr, Z_ave, timeinterp, hinterp, theta, kazr_xlim, xtickskazr,
     " (e)  Andenes AMF1 Averaged"),
    (timekazr, hkazr, Z, timeinterp, hinterp, theta, kazr_xlim, xtickskazr,
     " (f)  Andenes AMF1"),
]
for k, (axis, spec) in enumerate(zip(ax, panels)):
    circle = _plot_panel(axis, *spec, bottom=(k == len(panels) - 1))

# Shared colour bar; every panel uses the same cmap/norm, so the last mesh stands for all.
cax = fig.add_axes([0.94, 0.1, 0.025, 0.8])
cbar = fig.colorbar(circle, norm=norm, cax=cax, extend="neither")
cbar.set_ticks(bounds[::5])
# Whole-dBZ integer labels (bounds are integers), matching the '%i' contour labels.
cbar.set_ticklabels(['%i' % b for b in bounds[::5]])
cbar.ax.tick_params(labelsize=15)
cbar.set_label("Reflectivity [dBZ]", fontsize=15)


### write the figure ###
for out_dir in ("figures", "data_out"):
    os.makedirs(out_dir, exist_ok=True)  # no-op when the directory already exists
figure_path = "figures/vert_transect_03{}20_{}-{}{}.png".format(
    dates_str, start_time_str, end_time_str, mins)
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.close()


### save column maximum reflectivity to file ###
out_path = ("data_out/max_ref_data_03" + dates_str + "20_" + start_time_str + "-"
            + end_time_str + mins + '.nc')
# The context manager closes (and flushes) the file even if a write below fails.
with netCDF4.Dataset(out_path, 'w', format='NETCDF4') as f:
    # Separate unlimited time dimensions: the model and AMF1 series differ in length.
    f.createDimension('Time_model', None)
    f.createDimension('Time_amf1', None)
    time_model = f.createVariable('Time_model', 'i4', 'Time_model')
    time_amf1 = f.createVariable('Time_amf1', 'i4', 'Time_amf1')

    # (variable name, dimension, field); each field is (time, height), so axis=1 is the column.
    max_ref_fields = [
        ('max_ref_Level_2_5', 'Time_model', Z1),
        ('max_ref_Level_3', 'Time_model', Z2),
        ('max_ref_Level_2_5_EDMF', 'Time_model', Z3),
        ('max_ref_ysu', 'Time_model', Z4),
        ('max_ref_amf1', 'Time_amf1', Z),
        ('max_ref_amf1' + mins, 'Time_amf1', Z_ave),
    ]
    ref_vars = [f.createVariable(name, 'f4', dim) for name, dim, _ in max_ref_fields]

    # Times are stored as 'i4' integers, so any sub-second part is not kept.
    time_model[:] = time
    time_amf1[:] = timekazr
    for var, (_, _, field) in zip(ref_vars, max_ref_fields):
        var[:] = np.nanmax(field, axis=1)
